{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "14a172d1-597f-4251-94cf-816d476cf6b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Taichi] version 1.4.1, llvm 15.0.4, commit e67c674e, linux, python 3.9.16\n",
      "[I 03/29/23 20:11:03.458 2437593] [shell.py:_shell_pop_print@23] Graphical python shell detected, using wrapped sys.stdout\n",
      "[Taichi] Starting on arch=cuda\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import taichi as ti\n",
    "import torch\n",
    "from taichi.math import uvec3\n",
    "\n",
    "# Start Taichi on the CUDA backend with device_memory_GB set to 4.\n",
    "taichi_init_args = dict(arch=ti.cuda, device_memory_GB=4.0)\n",
    "ti.init(**taichi_init_args)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3aa17a75-66c8-4c76-88c3-3d364ecf8219",
   "metadata": {},
   "outputs": [],
   "source": [
    "@ti.kernel\n",
    "def torch2ti(field: ti.template(), data: ti.types.ndarray()):\n",
    "    \"\"\"Write each element of `data` into `field` at the same index.\n",
    "\n",
    "    The loop ranges over the indices of `data`, so `field` entries outside\n",
    "    that range are left untouched.\n",
    "    \"\"\"\n",
    "    for index in ti.grouped(data):\n",
    "        field[index] = data[index]\n",
    "\n",
    "\n",
    "@ti.kernel\n",
    "def ti2torch(field: ti.template(), data: ti.types.ndarray()):\n",
    "    \"\"\"Read `field` into `data`, one element per index of `data`.\n",
    "\n",
    "    Only the indices covered by `data` are visited.\n",
    "    \"\"\"\n",
    "    for index in ti.grouped(data):\n",
    "        data[index] = field[index]\n",
    "\n",
    "\n",
    "@ti.kernel\n",
    "def ti2torch_grad(field: ti.template(), grad: ti.types.ndarray()):\n",
    "    \"\"\"Read `field.grad` into `grad`, one element per index of `grad`.\"\"\"\n",
    "    for index in ti.grouped(grad):\n",
    "        grad[index] = field.grad[index]\n",
    "\n",
    "\n",
    "@ti.kernel\n",
    "def torch2ti_grad(field: ti.template(), grad: ti.types.ndarray()):\n",
    "    \"\"\"Write each element of `grad` into `field.grad` at the same index.\"\"\"\n",
    "    for index in ti.grouped(grad):\n",
    "        field.grad[index] = grad[index]\n",
    "        \n",
    "@ti.kernel\n",
    "def random_initialize(data: ti.types.ndarray()):\n",
    "    \"\"\"Overwrite every element of `data` with a rescaled `ti.random()` draw.\n",
    "\n",
    "    Each draw `r` is stored as `(r * 2.0 - 1.0) * 1e-4`.\n",
    "    \"\"\"\n",
    "    for index in ti.grouped(data):\n",
    "        sample = ti.random()\n",
    "        data[index] = (sample * 2.0 - 1.0) * 1e-4\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e3892b1e-2743-4a53-a4d1-21c51d0c1343",
   "metadata": {},
   "outputs": [],
   "source": [
    "@ti.kernel\n",
    "def hash_kernel(\n",
    "        xyzs: ti.template(), table: ti.template(),\n",
    "        xyzs_embedding: ti.template(), B: ti.i32):\n",
    "    \"\"\"Look up a trilinearly interpolated two-feature embedding for `B` points.\n",
    "\n",
    "    Args:\n",
    "        xyzs: field indexed as [point, axis]; axes 0..2 of each point are read.\n",
    "        table: 1-D field of interleaved feature pairs, so hash entry `k` lives\n",
    "            at `table[2 * k]` and `table[2 * k + 1]`.\n",
    "        xyzs_embedding: field indexed as [point, feature]; features 0 and 1 of\n",
    "            each processed point are written.\n",
    "        B: number of leading points to process.\n",
    "    \"\"\"\n",
    "\n",
    "    # get hash table embedding\n",
    "    for i in ti.ndrange(B):\n",
    "        xyz = ti.Vector([xyzs[i, 0], xyzs[i, 1], xyzs[i, 2]])\n",
    "        resolution = 512\n",
    "\n",
    "        # Scale into grid coordinates, then split into the integer cell corner\n",
    "        # (pos_grid) and the fractional offset inside that cell (pos).\n",
    "        pos = xyz * (resolution-1) + 0.5\n",
    "        pos_grid = ti.cast(ti.floor(pos), ti.uint32)\n",
    "        pos -= pos_grid\n",
    "\n",
    "        # Each hash entry owns two consecutive table slots, so only\n",
    "        # table.shape[0] // 2 entries are addressable. A modulus equal to the\n",
    "        # full table length would let `index * 2 + 1` run past the end of the\n",
    "        # field.\n",
    "        map_size = ti.static(table.shape[0] // 2)\n",
    "\n",
    "        local_feature_0 = 0.0\n",
    "        local_feature_1 = 0.0\n",
    "\n",
    "        # Visit the 8 corners of the cell; bit d of `idx` selects the low or\n",
    "        # high corner along axis d.\n",
    "        for idx in ti.static(range(8)):\n",
    "            w = 1.\n",
    "            pos_grid_local = uvec3(0)\n",
    "\n",
    "            # Linear interpolation\n",
    "            for d in ti.static(range(3)):\n",
    "                if (idx & (1 << d)) == 0:\n",
    "                    pos_grid_local[d] = pos_grid[d]\n",
    "                    w *= 1 - pos[d]\n",
    "                else:\n",
    "                    pos_grid_local[d] = pos_grid[d] + 1\n",
    "                    w *= pos[d]\n",
    "\n",
    "            # Hash\n",
    "            _hash_index = ti.uint32(0)\n",
    "            primes = uvec3(ti.uint32(1), ti.uint32(2654435761), ti.uint32(805459861))\n",
    "            # Named `axis` so the point index `i` of the outer loop is not shadowed.\n",
    "            for axis in ti.static(range(3)):\n",
    "                _hash_index ^= ti.uint32(pos_grid_local[axis]) * primes[axis] # add randomness\n",
    "            index = _hash_index % map_size\n",
    "\n",
    "            index_table = ti.cast(index * 2, ti.int32) # each position consists of two elements\n",
    "            local_feature_0 += w * table[index_table]\n",
    "            local_feature_1 += w * table[index_table + 1]\n",
    "\n",
    "        xyzs_embedding[i, 0] = local_feature_0\n",
    "        xyzs_embedding[i, 1] = local_feature_1\n",
    " \n",
    "\n",
    "class ToyHashEncoder(torch.nn.Module):\n",
    "    \"\"\"Single-level hash-grid positional encoder backed by a Taichi kernel.\n",
    "\n",
    "    The learnable table is `self.hash_table` (a torch Parameter). A forward\n",
    "    pass copies the positions and the table into Taichi fields, runs\n",
    "    `hash_kernel`, and copies the embedding back into a torch tensor. The\n",
    "    backward pass runs `hash_kernel.grad` on the same fields.\n",
    "\n",
    "    NOTE: backward reads the positions from `self.input_fields`, which every\n",
    "    forward call overwrites, so run each backward before calling the encoder\n",
    "    again.\n",
    "\n",
    "    Args:\n",
    "        batch_size: the staging fields hold up to `batch_size * 1024` points.\n",
    "        out_dim: number of embedding columns; `hash_kernel` writes columns 0\n",
    "            and 1, so this must be at least 2.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `out_dim` is smaller than 2.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, batch_size=8192, out_dim=2):\n",
    "        super(ToyHashEncoder, self).__init__()\n",
    "\n",
    "        if out_dim < 2:\n",
    "            # hash_kernel unconditionally writes embedding columns 0 and 1.\n",
    "            raise ValueError('out_dim must be at least 2, got {}'.format(out_dim))\n",
    "\n",
    "        self.out_dim = out_dim\n",
    "        self.total_hash_size = 2**19\n",
    "        print(\"total_hash_size: \", self.total_hash_size)\n",
    "\n",
    "        self.hash_table = torch.nn.Parameter(\n",
    "            torch.zeros(self.total_hash_size,\n",
    "                        dtype=torch.float32),\n",
    "            requires_grad=True\n",
    "        )\n",
    "\n",
    "        random_initialize(self.hash_table)\n",
    "\n",
    "        # Row capacity of the input/output staging fields allocated below.\n",
    "        self.max_points = batch_size * 1024\n",
    "\n",
    "        self.parameter_fields = ti.field(dtype=ti.f32,\n",
    "                                         shape=(self.total_hash_size, ),\n",
    "                                         needs_grad=True)\n",
    "\n",
    "        self.input_fields = ti.field(dtype=ti.f32,\n",
    "                                     shape=(self.max_points, 3),\n",
    "                                     needs_grad=True)\n",
    "\n",
    "        self.output_fields = ti.field(dtype=ti.f32,\n",
    "                                      shape=(self.max_points, self.out_dim),\n",
    "                                      needs_grad=True)\n",
    "\n",
    "        self._hash_kernel = hash_kernel\n",
    "\n",
    "        # Defined inside __init__ so the static methods can reach this\n",
    "        # encoder's fields through the enclosing `self`.\n",
    "        class _module_function(torch.autograd.Function):\n",
    "\n",
    "            @staticmethod\n",
    "            def forward(ctx, input_pos, params):\n",
    "                num_points = input_pos.shape[0]\n",
    "                if num_points > self.max_points:\n",
    "                    # Copying more rows than the fields hold would index\n",
    "                    # past their end.\n",
    "                    raise ValueError(\n",
    "                        'got {} points but the Taichi fields hold at most {}'.format(\n",
    "                            num_points, self.max_points))\n",
    "\n",
    "                ctx.input_size = input_pos.shape\n",
    "\n",
    "                output_embedding = torch.zeros(num_points, self.out_dim,\n",
    "                                               dtype=input_pos.dtype,\n",
    "                                               device=input_pos.device)\n",
    "\n",
    "                torch2ti(self.input_fields, input_pos.contiguous())\n",
    "                torch2ti(self.parameter_fields, params.contiguous())\n",
    "\n",
    "                self._hash_kernel(\n",
    "                    self.input_fields,\n",
    "                    self.parameter_fields,\n",
    "                    self.output_fields, # output\n",
    "                    num_points,\n",
    "                )\n",
    "                ti2torch(self.output_fields, output_embedding)\n",
    "\n",
    "                return output_embedding\n",
    "\n",
    "            @staticmethod\n",
    "            def backward(ctx, doutput):\n",
    "                input_size = ctx.input_size\n",
    "\n",
    "                # Clear only the Taichi-side gradients; the torch parameter's\n",
    "                # .grad must survive so autograd can keep accumulating into it.\n",
    "                self._zero_field_grads()\n",
    "\n",
    "                hash_grad = torch.zeros(self.total_hash_size,\n",
    "                                        dtype=doutput.dtype,\n",
    "                                        device=doutput.device)\n",
    "\n",
    "                input_grad = torch.zeros(*input_size,\n",
    "                                         dtype=doutput.dtype,\n",
    "                                         device=doutput.device)\n",
    "\n",
    "                torch2ti_grad(self.output_fields, doutput.contiguous())\n",
    "                self._hash_kernel.grad(\n",
    "                    self.input_fields,\n",
    "                    self.parameter_fields,\n",
    "                    self.output_fields,\n",
    "                    doutput.shape[0],\n",
    "                )\n",
    "                ti2torch_grad(self.parameter_fields, hash_grad)\n",
    "                ti2torch_grad(self.input_fields, input_grad)\n",
    "                return input_grad, hash_grad\n",
    "\n",
    "        self._module_function = _module_function\n",
    "\n",
    "    def _zero_field_grads(self):\n",
    "        \"\"\"Reset the gradients of the Taichi input and parameter fields to zero.\"\"\"\n",
    "        self.input_fields.grad.fill(0.)\n",
    "        self.parameter_fields.grad.fill(0.)\n",
    "\n",
    "    def zero_grad(self, *args, **kwargs):\n",
    "        \"\"\"Clear the torch parameter gradients and the Taichi field gradients.\n",
    "\n",
    "        Arguments are forwarded to `torch.nn.Module.zero_grad`, so the module\n",
    "        keeps the standard behaviour in addition to resetting the fields.\n",
    "        \"\"\"\n",
    "        super(ToyHashEncoder, self).zero_grad(*args, **kwargs)\n",
    "        self._zero_field_grads()\n",
    "\n",
    "    def forward(self, positions):\n",
    "        \"\"\"Return the embedding of `positions`, one row of `out_dim` values per point.\"\"\"\n",
    "        return self._module_function.apply(positions, self.hash_table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5f6390ef-d51e-4ae0-8f78-1046c0fef1ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total_hash_size:  524288\n"
     ]
    }
   ],
   "source": [
    "pos_encoder = ToyHashEncoder()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b4fce2c8-1919-4ada-8bc0-0aeef6bfc1bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)  # fixed seed so the sampled positions are reproducible\n",
    "xyz = torch.rand(8192, 3)\n",
    "h = pos_encoder(xyz)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6e7f1248-c9fd-443e-9959-55ec7089186d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "h shape:  torch.Size([8192, 2])\n"
     ]
    }
   ],
   "source": [
    "print(\"h shape: \", h.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ff88a51a-0c3e-4785-9dd4-f981ad610bf3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pytorch total\n",
      " -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  \n",
      "-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "                               _module_functionBackward        32.40%       9.339ms        56.30%      16.228ms       1.623ms      10.140ms        34.93%      16.274ms       1.627ms            10  \n",
      "                                       _module_function        33.99%       9.796ms        34.69%       9.998ms     999.800us       9.825ms        33.84%      10.045ms       1.004ms            10  \n",
      "                                            aten::fill_        22.32%       6.433ms        22.32%       6.433ms     214.433us       5.769ms        19.87%       5.769ms     192.300us            30  \n",
      "                                             aten::add_         1.77%     511.000us         1.77%     511.000us      51.100us     534.000us         1.84%     534.000us      53.400us            10  \n",
      "                                              aten::mul         0.99%     285.000us         0.99%     285.000us       9.500us     351.000us         1.21%     351.000us      11.700us            30  \n",
      "-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "Self CPU time total: 28.824ms\n",
      "Self CUDA time total: 29.031ms\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "STAGE:2023-03-29 20:15:14 2437593:2437593 ActivityProfilerController.cpp:300] Completed Stage: Warm Up\n",
      "STAGE:2023-03-29 20:15:14 2437593:2437593 ActivityProfilerController.cpp:306] Completed Stage: Collection\n",
      "STAGE:2023-03-29 20:15:14 2437593:2437593 ActivityProfilerController.cpp:310] Completed Stage: Post Processing\n"
     ]
    }
   ],
   "source": [
    "repeat = 10\n",
    "\n",
    "# Warm-up pass outside the profiler: Taichi compiles a kernel the first time it\n",
    "# is called, and that one-off cost should not be counted in the timings below.\n",
    "warmup = pos_encoder(xyz)\n",
    "((warmup * warmup) - torch.tanh(warmup)).sum().backward()\n",
    "ti.sync()  # wait for previously launched Taichi kernels before timing starts\n",
    "\n",
    "with torch.autograd.profiler.profile(use_cuda=True) as prof:\n",
    "    for _ in range(repeat):\n",
    "        # check forward\n",
    "        h = pos_encoder(xyz)\n",
    "        loss = ((h * h) - torch.tanh(h)).sum()\n",
    "        # Backward\n",
    "        loss.backward()\n",
    "\n",
    "print(\n",
    "    'pytorch total\\n',\n",
    "    prof.key_averages(group_by_stack_n=5).table(sort_by='self_cuda_time_total', row_limit=5)\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cb70872-63f5-4f7d-bdf8-d08c00735b9e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:torch_nightly]",
   "language": "python",
   "name": "conda-env-torch_nightly-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
